{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DataLoader 和 Dataset\n",
    "\n",
    "构建模型的基本方法，我们了解了。接下来，我们就要弄明白怎么对数据进行预处理，然后加载数据，我们以前手动加载数据的方式，在数据量小的时候，并没有太大问题，但是到了大数据量，我们需要使用 shuffle, 分割成mini-batch 等操作的时候，我们可以使用PyTorch的API快速地完成这些操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "from torch.autograd import Variable\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32) # 使用numpy读取数据\n",
    "x_data = torch.from_numpy(xy[:, 0:-1])\n",
    "y_data = torch.from_numpy(xy[:, [-1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([759, 8]) torch.Size([759, 1])\n"
     ]
    }
   ],
   "source": [
    "print(x_data.shape, y_data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dataset是一个包装类，用来将数据包装为Dataset类，然后传入DataLoader中，我们再使用DataLoader这个类来更加快捷的对数据进行操作。\n",
    "\n",
    "DataLoader是一个比较重要的类，它为我们提供的常用操作有：batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程)\n",
    "\n",
    "现在，我们先展示直接使用 TensorDataset 来将数据包装成Dataset类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "deal_dataset = TensorDataset(data_tensor=x_data, target_tensor=y_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(dataset=deal_dataset,\n",
    "                          batch_size=32,\n",
    "                          shuffle=True,\n",
    "                          num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 1 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 2 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 3 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 4 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 5 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 6 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 7 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 8 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 9 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 10 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 11 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 12 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 13 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 14 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 15 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 16 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 17 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 18 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 19 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 20 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 21 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 22 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "0 23 inputs torch.Size([23, 8]) labels torch.Size([23, 1])\n",
      "1 0 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 1 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 2 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 3 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 4 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 5 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 6 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 7 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 8 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 9 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 10 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 11 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 12 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 13 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 14 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 15 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 16 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 17 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 18 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 19 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 20 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 21 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 22 inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "1 23 inputs torch.Size([23, 8]) labels torch.Size([23, 1])\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(2):\n",
    "    for i, data in enumerate(train_loader):\n",
    "        # 将数据从 train_loader 中读出来,一次读取的样本数是32个\n",
    "        inputs, labels = data\n",
    "\n",
    "        # 将这些数据转换成Variable类型\n",
    "        inputs, labels = Variable(inputs), Variable(labels)\n",
    "\n",
    "        # 接下来就是跑模型的环节了，我们这里使用print来代替\n",
    "        print(epoch, i, \"inputs\", inputs.data.size(), \"labels\", labels.data.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们来继承 Dataset类 ，写一个将数据处理成DataLoader的类。\n",
    "\n",
    "当我们集成了一个 Dataset类之后，我们需要重写 \\_\\_len\\_\\_ 方法，该方法提供了dataset的大小；  \\_\\_getitem\\_\\_ 方法， 该方法支持从 0 到 len(self)的索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch： 0 的第 0 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 1 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 2 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 3 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 4 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 5 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 6 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 7 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 8 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 9 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 10 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 11 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 12 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 13 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 14 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 15 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 16 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 17 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 18 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 19 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 20 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 21 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 22 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 0 的第 23 个inputs torch.Size([23, 8]) labels torch.Size([23, 1])\n",
      "epoch： 1 的第 0 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 1 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 2 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 3 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 4 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 5 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 6 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 7 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 8 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 9 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 10 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 11 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 12 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 13 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 14 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 15 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 16 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 17 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 18 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 19 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 20 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 21 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 22 个inputs torch.Size([32, 8]) labels torch.Size([32, 1])\n",
      "epoch： 1 的第 23 个inputs torch.Size([23, 8]) labels torch.Size([23, 1])\n"
     ]
    }
   ],
   "source": [
    "class DealDataset(Dataset):\n",
    "    \"\"\"\n",
    "        下载数据、初始化数据，都可以在这里完成\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32) # 使用numpy读取数据\n",
    "        self.x_data = torch.from_numpy(xy[:, 0:-1])\n",
    "        self.y_data = torch.from_numpy(xy[:, [-1]])\n",
    "        self.len = xy.shape[0]\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        return self.x_data[index], self.y_data[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len\n",
    "\n",
    "# 实例化这个类，然后我们就得到了Dataset类型的数据，记下来就将这个类传给DataLoader，就可以了。    \n",
    "dealDataset = DealDataset()\n",
    "\n",
    "train_loader2 = DataLoader(dataset=dealDataset,\n",
    "                          batch_size=32,\n",
    "                          shuffle=True)\n",
    "\n",
    "\n",
    "for epoch in range(2):\n",
    "    for i, data in enumerate(train_loader2):\n",
    "        # 将数据从 train_loader 中读出来,一次读取的样本数是32个\n",
    "        inputs, labels = data\n",
    "\n",
    "        # 将这些数据转换成Variable类型\n",
    "        inputs, labels = Variable(inputs), Variable(labels)\n",
    "\n",
    "        # 接下来就是跑模型的环节了，我们这里使用print来代替\n",
    "        print(\"epoch：\", epoch, \"的第\" , i, \"个inputs\", inputs.data.size(), \"labels\", labels.data.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# torchvision 包的介绍\n",
    "\n",
    "torchvision 是PyTorch中专门用来处理图像的库，PyTorch官网的安装教程，也会让你安装上这个包。\n",
    "\n",
    "这个包中有四个大类。\n",
    "\n",
    "torchvision.datasets\n",
    "\n",
    "torchvision.models\n",
    "\n",
    "torchvision.transforms\n",
    "\n",
    "torchvision.utils\n",
    "\n",
    "这里我们主要介绍前三个。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## torchvision.datasets\n",
    "\n",
    "torchvision.datasets 是用来进行数据加载的，PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。\n",
    "\n",
    "\n",
    "* MNIST\n",
    "* COCO\n",
    "* Captions\n",
    "* Detection\n",
    "* LSUN\n",
    "* ImageFolder\n",
    "* Imagenet-12\n",
    "* CIFAR\n",
    "* STL10\n",
    "* SVHN\n",
    "* PhotoTour\n",
    "\n",
    "我们可以直接使用，示例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "DOWNLOAD = True\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的相对目录\n",
    "                                      train=True,  # 表示是否加载数据库的训练集，false的时候加载测试集\n",
    "                                      download=DOWNLOAD, # 表示是否自动下载 MNIST 数据集\n",
    "                                      transform=None) # 表示是否需要对数据进行预处理，none为不进行预处理\n",
    "\n",
    "# 上面代码就完成了MNIST数据 训练集的加载环节，\n",
    "train_loader2 = DataLoader(dataset=trainset,\n",
    "                          batch_size=32,\n",
    "                          shuffle=True)\n",
    "print(\"训练集总长度为：\" , len(trainset))\n",
    "print(\"每个mini-batch的size 为 32 , 一共有：\" , len(train_loader2) , \"个\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torchvision.models\n",
    "\n",
    "torchvision.models 中为我们提供了已经训练好的模型，让我们可以加载之后，直接使用。\n",
    "\n",
    "torchvision.models模块的 子模块中包含以下模型结构。\n",
    "* AlexNet\n",
    "* VGG\n",
    "* ResNet\n",
    "* SqueezeNet\n",
    "* DenseNet\n",
    "\n",
    "我们可以直接使用如下代码来快速创建一个权重随机初始化的模型\n",
    "\n",
    "    import torchvision.models as models\n",
    "    resnet18 = models.resnet18()\n",
    "    alexnet = models.alexnet()\n",
    "    squeezenet = models.squeezenet1_0()\n",
    "    densenet = models.densenet_161()\n",
    "    \n",
    "也可以通过使用 pretrained=True 来加载一个别人预训练好的模型\n",
    "\n",
    "    import torchvision.models as models\n",
    "    resnet18 = models.resnet18(pretrained=True)\n",
    "    alexnet = models.alexnet(pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "# 加载一个 resnet18 模型\n",
    "resnet18 = models.resnet18()\n",
    "print(resnet18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "resnet18 = models.resnet18(pretrained=True) # 加载一个已经预训练好的模型， 需要下载一段时间... "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torchvision.transforms\n",
    "\n",
    "transforms 模块提供了一般的图像转换操作类。\n",
    "\n",
    "举两个例子\n",
    "\n",
    "    class torchvision.transforms.ToTensor \n",
    "\n",
    "> 把shape=(H x W x C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]的torch.FloatTensor。\n",
    "\n",
    "\n",
    "    class torchvision.transforms.Normalize(mean, std) \n",
    "\n",
    "> 给定均值(R, G, B)和标准差(R, G, B)，用公式channel = (channel - mean) / std进行规范化。\n",
    "\n",
    "当然，也可以使用 \n",
    "    \n",
    "    class torchvision.transforms.RandomCrop(size, padding=0)\n",
    "    \n",
    "或者\n",
    "\n",
    "    class torchvision.transforms.RandomSizedCrop(size, interpolation=2)\n",
    "    \n",
    "来进行图像的 augment\n",
    "\n",
    "然后如果需要同时进行这些操作，我们可以使用一个\n",
    "\n",
    "    class torchvision.transforms.Compose(transforms)\n",
    "\n",
    "来把多个transform组合起来使用。\n",
    "\n",
    "如下所示:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 我们这里还是对MNIST进行处理，初始的MNIST是 28 * 28，我们把它处理成 96 * 96 的torch.Tensor的格式\n",
    "from torchvision import transforms as transforms\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# 图像预处理步骤\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(96), # 缩放到 96 * 96 大小\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化\n",
    "])\n",
    "\n",
    "DOWNLOAD = True\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=DOWNLOAD)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=True)\n",
    "\n",
    "print(len(train_dataset))\n",
    "print(len(train_loader))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
